{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple RNN\n",
    "\n",
    "In this notebook, we're going to train a simple RNN to do **time-series prediction**. Given some set of input data, it should be able to generate a prediction for the next time step!\n",
    "<img src='assets/time_prediction.png' width=40% />\n",
    "\n",
    "> * First, we'll create our data\n",
    "* Then, define an RNN in PyTorch\n",
    "* Finally, we'll train our network and see how it performs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import resources and create data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(8,5))\n",
    "\n",
    "# how many time steps/data pts are in one batch of data\n",
    "seq_length = 20\n",
    "\n",
    "# generate evenly spaced data pts\n",
    "time_steps = np.linspace(0, np.pi, seq_length + 1)\n",
    "data = np.sin(time_steps)\n",
    "data.resize((seq_length + 1, 1)) # size becomes (seq_length+1, 1), adds an input_size dimension\n",
    "\n",
    "x = data[:-1] # all but the last piece of data\n",
    "y = data[1:] # all but the first\n",
    "\n",
    "# display the data\n",
    "plt.plot(time_steps[1:], x, 'r.', label='input, x') # x\n",
    "plt.plot(time_steps[1:], y, 'b.', label='target, y') # y\n",
    "\n",
    "plt.legend(loc='best')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Define the RNN\n",
    "\n",
    "Next, we define an RNN in PyTorch. We'll use `nn.RNN` to create an RNN layer, then we'll add a last, fully-connected layer to get the output size that we want. An RNN takes in a number of parameters:\n",
    "* **input_size** - the size of the input\n",
    "* **hidden_dim** - the number of features in the RNN output and in the hidden state\n",
    "* **n_layers** - the number of layers that make up the RNN, typically 1-3; greater than 1 means that you'll create a stacked RNN\n",
    "* **batch_first** - whether or not the input/output of the RNN will have the batch_size as the first dimension (batch_size, seq_length, hidden_dim)\n",
    "\n",
    "Take a look at the [RNN documentation](https://pytorch.org/docs/stable/nn.html#rnn) to read more about recurrent layers."
   ]
  },
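  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the shapes described above, the sketch below passes a random batch through a standalone `nn.RNN` layer. The sizes used here (`batch_size=3`, `seq_length=20`, `hidden_dim=10`, `n_layers=2`) are arbitrary values assumed for this example only.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# arbitrary example sizes (assumptions for illustration only)\n",
    "batch_size, seq_length, input_size = 3, 20, 1\n",
    "hidden_dim, n_layers = 10, 2\n",
    "\n",
    "# nn.RNN names these arguments input_size, hidden_size and num_layers\n",
    "rnn_layer = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)\n",
    "\n",
    "x = torch.randn(batch_size, seq_length, input_size)\n",
    "r_out, hidden = rnn_layer(x)  # the hidden state defaults to zeros when omitted\n",
    "\n",
    "print(r_out.shape)   # (batch_size, seq_length, hidden_dim)\n",
    "print(hidden.shape)  # (n_layers, batch_size, hidden_dim)\n",
    "```\n",
    "\n",
    "Note that `batch_first=True` only changes the layout of the input and output tensors; the hidden state keeps the shape `(n_layers, batch_size, hidden_dim)`."
   ]
  },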
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class RNN(nn.Module):\n",
    "    def __init__(self, input_size, output_size, hidden_dim, n_layers):\n",
    "        super(RNN, self).__init__()\n",
    "        \n",
    "        self.hidden_dim=hidden_dim\n",
    "\n",
    "        # define an RNN with specified parameters\n",
    "        # batch_first means that the first dim of the input and output will be the batch_size\n",
    "        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True)\n",
    "        \n",
    "        # last, fully-connected layer\n",
    "        self.fc = nn.Linear(hidden_dim, output_size)\n",
    "\n",
    "    def forward(self, x, hidden):\n",
    "        # x (batch_size, seq_length, input_size)\n",
    "        # hidden (n_layers, batch_size, hidden_dim)\n",
    "        # r_out (batch_size, time_step, hidden_size)\n",
    "        batch_size = x.size(0)\n",
    "        \n",
    "        # get RNN outputs\n",
    "        r_out, hidden = self.rnn(x, hidden)\n",
    "        # shape output to be (batch_size*seq_length, hidden_dim)\n",
    "        r_out = r_out.view(-1, self.hidden_dim)  \n",
    "        \n",
    "        # get final output \n",
    "        output = self.fc(r_out)\n",
    "        \n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check the input and output dimensions\n",
    "\n",
    "As a check that your model is working as expected, test out how it responds to input data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test that dimensions are as expected\n",
    "test_rnn = RNN(input_size=1, output_size=1, hidden_dim=10, n_layers=2)\n",
    "\n",
    "# generate evenly spaced, test data pts\n",
    "time_steps = np.linspace(0, np.pi, seq_length)\n",
    "data = np.sin(time_steps)\n",
    "data.resize((seq_length, 1))\n",
    "\n",
    "test_input = torch.Tensor(data).unsqueeze(0) # give it a batch_size of 1 as first dimension\n",
    "print('Input size: ', test_input.size())\n",
    "\n",
    "# test out rnn sizes\n",
    "test_out, test_h = test_rnn(test_input, None)\n",
    "print('Output size: ', test_out.size())\n",
    "print('Hidden state size: ', test_h.size())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Training the RNN\n",
    "\n",
    "Next, we'll instantiate an RNN with some specified hyperparameters. Then train it over a series of steps, and see how it performs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# decide on hyperparameters\n",
    "input_size=1 \n",
    "output_size=1\n",
    "hidden_dim=32\n",
    "n_layers=1\n",
    "\n",
    "# instantiate an RNN\n",
    "rnn = RNN(input_size, output_size, hidden_dim, n_layers)\n",
    "print(rnn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss and Optimization\n",
    "\n",
    "This is a regression problem: can we train an RNN to accurately predict the next data point, given a current data point?\n",
    "\n",
* The data points">
    ">* The data points are coordinate values, so to compare a predicted point with its ground-truth point, we'll use a regression loss: the mean squared error.\n",
    "* It's typical to use an Adam optimizer for recurrent models."
   ]
  },
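  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loss concrete, here is a small sketch with made-up numbers (the tensors `pred` and `target` are hypothetical, not data from this notebook). With its default `reduction='mean'`, `nn.MSELoss` averages the squared differences over all elements, which the sketch checks against a hand-written version.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# hypothetical predictions and targets, shaped (seq_length, output_size)\n",
    "pred = torch.tensor([[0.0], [0.5], [1.0]])\n",
    "target = torch.tensor([[0.0], [1.0], [0.0]])\n",
    "\n",
    "mse = nn.MSELoss()\n",
    "loss = mse(pred, target)\n",
    "\n",
    "# the same quantity computed by hand: mean of the squared differences\n",
    "manual = ((pred - target) ** 2).mean()\n",
    "\n",
    "print(loss.item(), manual.item())\n",
    "```\n",
    "\n",
    "The prediction and the target should have the same shape; otherwise broadcasting can silently produce a different loss than intended."
   ]
  },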
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# MSE loss and Adam optimizer with a learning rate of 0.01\n",
    "criterion = nn.MSELoss()\n",
    "optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Defining the training function\n",
    "\n",
    "This function takes in an rnn, a number of steps to train for, and returns a trained rnn. This function is also responsible for displaying the loss and the predictions, every so often.\n",
    "\n",
    "#### Hidden State\n",
    "\n",
    "Pay close attention to the hidden state, here:\n",
    "* Before looping over a batch of training data, the hidden state is initialized\n",
    "* After a new hidden state is generated by the rnn, we get the latest hidden state, and use that as input to the rnn for the following steps"
   ]
  },
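  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hidden-state handling described above can be sketched in isolation. This is a minimal illustration using a standalone `nn.RNN` layer and random inputs (all sizes are arbitrary assumptions): the hidden state returned by one call is detached, so the next call reuses its values without extending the autograd graph back through earlier batches.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "layer = nn.RNN(input_size=1, hidden_size=4, batch_first=True)\n",
    "\n",
    "hidden = None  # None tells nn.RNN to start from a zero hidden state\n",
    "for step in range(3):\n",
    "    x = torch.randn(1, 5, 1)  # (batch_size, seq_length, input_size)\n",
    "    out, hidden = layer(x, hidden)\n",
    "\n",
    "    # the returned hidden state is still attached to the graph\n",
    "    attached = hidden.grad_fn is not None\n",
    "\n",
    "    # detach keeps the values but drops the history\n",
    "    hidden = hidden.detach()\n",
    "    print(step, attached, hidden.grad_fn is None)\n",
    "```\n",
    "\n",
    "Without the detach step, calling `backward()` on a later batch would try to backpropagate through the graphs of all earlier batches."
   ]
  },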
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# train the RNN\n",
    "def train(rnn, n_steps, print_every):\n",
    "    \n",
    "    # initialize the hidden state\n",
    "    hidden = None      \n",
    "    \n",
    "    for batch_i, step in enumerate(range(n_steps)):\n",
    "        # defining the training data \n",
    "        time_steps = np.linspace(step * np.pi, (step+1)*np.pi, seq_length + 1)\n",
    "        data = np.sin(time_steps)\n",
    "        data.resize((seq_length + 1, 1)) # input_size=1\n",
    "\n",
    "        x = data[:-1]\n",
    "        y = data[1:]\n",
    "        \n",
    "        # convert data into Tensors\n",
    "        x_tensor = torch.Tensor(x).unsqueeze(0) # unsqueeze gives a 1, batch_size dimension\n",
    "        y_tensor = torch.Tensor(y)\n",
    "\n",
    "        # outputs from the rnn\n",
    "        prediction, hidden = rnn(x_tensor, hidden)\n",
    "\n",
    "        ## Representing Memory ##\n",
    "        # make a new variable for hidden and detach the hidden state from its history\n",
    "        # this way, we don't backpropagate through the entire history\n",
    "        hidden = hidden.data\n",
    "\n",
    "        # calculate the loss\n",
    "        loss = criterion(prediction, y_tensor)\n",
    "        # zero gradients\n",
    "        optimizer.zero_grad()\n",
    "        # perform backprop and update weights\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # display loss and predictions\n",
    "        if batch_i%print_every == 0:        \n",
    "            print('Loss: ', loss.item())\n",
    "            plt.plot(time_steps[1:], x, 'r.') # input\n",
    "            plt.plot(time_steps[1:], prediction.data.numpy().flatten(), 'b.') # predictions\n",
    "            plt.show()\n",
    "    \n",
    "    return rnn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train the rnn and monitor results\n",
    "n_steps = 75\n",
    "print_every = 15\n",
    "\n",
    "trained_rnn = train(rnn, n_steps, print_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Time-Series Prediction\n",
    "\n",
    "Time-series prediction can be applied to many tasks. Think about weather forecasting or predicting the ebb and flow of stock market prices. You can even try to generate predictions much further in the future than just one time step!"
   ]
  },
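  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to predict further than a single step is to feed each prediction back in as the next input. The sketch below shows only the mechanics, under stated assumptions: `TinyRNN` and `rollout` are hypothetical names introduced for this example, the model is untrained (so the predicted values are not meaningful), and unlike the `RNN` class above it keeps the sequence dimension in its output.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class TinyRNN(nn.Module):\n",
    "    # hypothetical stand-in model; the output keeps its sequence dimension\n",
    "    def __init__(self, hidden_dim=8):\n",
    "        super().__init__()\n",
    "        self.rnn = nn.RNN(1, hidden_dim, batch_first=True)\n",
    "        self.fc = nn.Linear(hidden_dim, 1)\n",
    "\n",
    "    def forward(self, x, hidden):\n",
    "        r_out, hidden = self.rnn(x, hidden)\n",
    "        return self.fc(r_out), hidden  # (batch_size, seq_length, 1)\n",
    "\n",
    "def rollout(model, seed, n_future):\n",
    "    # warm up on the seed sequence, then feed each prediction back in\n",
    "    model.eval()\n",
    "    preds = []\n",
    "    with torch.no_grad():\n",
    "        out, hidden = model(seed, None)\n",
    "        step_input = out[:, -1:, :]  # prediction for the step after the seed\n",
    "        for _ in range(n_future):\n",
    "            preds.append(step_input.item())\n",
    "            step_input, hidden = model(step_input, hidden)\n",
    "    return preds\n",
    "\n",
    "seed = torch.Tensor(np.sin(np.linspace(0, np.pi, 20))).view(1, 20, 1)\n",
    "future = rollout(TinyRNN(), seed, n_future=10)\n",
    "print(len(future))\n",
    "```\n",
    "\n",
    "Because every prediction becomes the next input, errors can compound over the rollout, so long-horizon predictions are usually less reliable than one-step predictions."
   ]
  },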
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda root]",
   "language": "python",
   "name": "conda-root-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
